import json
import os

import numpy as np
import torch

from ModularUtils.FunctionsConstant import plot_lines
epochs = 300  # entries kept from each saved curve; also caps the averaging horizon below
delta = 10  # width (in entries) of each averaging window
last_exp = "SAVED_EXPERIMENTS/mnist_nonId_newgraph/Exp1/Sep_20_2022-13_52"
kl_diff = {}  # legend label -> KL curve loaded from <exp>/kl/<dist>
tvd_diff = {}  # legend label -> TVD curve loaded from <exp>/tvd/<dist>
dashed = []  # labels forwarded to plot_lines as the dashed-line set
solid = []  # not used further in this script

# Selects the counterfactual distribution set (new_cf_dist) instead of the
# do()-labelled distribution sets.
isCF = False


def _load_curve(exp_dir, metric, dist_name, keep):
    """Load one saved metric curve and return it as a sliced NumPy array.

    Reads ``<exp_dir>/<metric>/<dist_name>`` with ``torch.load``, detaches it,
    moves it to the CPU, converts it to NumPy and applies the slice ``keep``.
    """
    curve = torch.load(os.path.join(exp_dir, metric, dist_name))
    return curve.detach().cpu().numpy()[keep]


if not isCF:
    # First experiment directory: keep the first `epochs` entries of each curve.
    dist_keys = {'P(X1X2WYcolor|do_[])': '$P(X_1,X_2,W,Color)$',
                 'P(X2WYcolor|doX1_0)': '$P(X_2,W,Color|do(X_1=0))$',
                 'P(X2WYcolor|doX1_1)': '$P(X_2,W,Color|do(X_1=1))$'}

    print("tvd diffs")
    first_epochs = slice(0, epochs)
    for dist in dist_keys:
        tvd_diff[dist_keys[dist]] = _load_curve(last_exp, "tvd", dist, first_epochs)
        kl_diff[dist_keys[dist]] = _load_curve(last_exp, "kl", dist, first_epochs)

    # Second experiment directory: keep the last `epochs` entries of each curve.
    last_exp = "SAVED_EXPERIMENTS/mnist_nonId_newgraph/Exp1/Sep_20_2022-17_15/"
    new_dist = {
        'P(Ydigit1Ydigit2Ythick|do_[])': '$P(Y_1,Y_2,Thick)$',
        'P(Ydigit1Ydigit2Ythick|doX1_0)': '$P(Y_1,Y_2,Thick|do(X_1=0))$',
        'P(Ydigit1Ydigit2Ythick|doX1_1)': '$P(Y_1,Y_2,Thick|do(X_1=1))$',
        # NOTE(review): this key uses "do[]" while the others use "do_[]";
        # presumably it mirrors the saved file name -- confirm.
        'P(X1X2WYdigit1Ydigit2YcolorYthick|do[])': '$P(X_1,X_2,W,Y_1,Y_2,Color,Thick)$',
        'P(X2WYdigit1Ydigit2YcolorYthick|doX1)': '$P(X_2,W,Y_1,Y_2,Color,Thick|do(X_1))$',
    }

    dist_keys.update(new_dist)
    print("tvd diffs")
    last_epochs = slice(-epochs, None)
    for dist in new_dist:
        tvd_diff[dist_keys[dist]] = _load_curve(last_exp, "tvd", dist, last_epochs)
        kl_diff[dist_keys[dist]] = _load_curve(last_exp, "kl", dist, last_epochs)

else:
    # Counterfactual distribution: keep the last `epochs` entries of the curve.
    new_cf_dist = {'P(Ycolor|do(X1,X2),X1p, X2p)': '$P_{X_1, X_2}(Color|X_1\',X_2\')$'}
    last_exp = "SAVED_EXPERIMENTS/mnist_nonId_newgraph/Exp1/Sep_20_2022-13_52/"
    last_epochs = slice(-epochs, None)
    for dist in new_cf_dist:
        tvd_diff[new_cf_dist[dist]] = _load_curve(last_exp, "tvd", dist, last_epochs)
        kl_diff[new_cf_dist[dist]] = _load_curve(last_exp, "kl", dist, last_epochs)

def _window_stats(series, window, horizon):
    """Summarise ``series`` over consecutive, non-overlapping windows.

    Only complete windows lying inside the first ``horizon`` entries are used;
    a trailing partial window is dropped. For each window this returns its
    mean and the mean absolute deviation of the window's entries from it.

    Args:
        series: sliceable numeric sequence (the curves here are NumPy arrays).
        window: window width in entries; must be positive.
        horizon: number of leading entries eligible for windowing.

    Returns:
        (means, deviations): two lists with one value per window.

    Raises:
        ValueError: if ``window`` is not positive (0 would never advance).
    """
    if window <= 0:
        raise ValueError("window must be positive, got %r" % (window,))
    means, deviations = [], []
    for start in range(0, (horizon // window) * window, window):
        chunk = series[start:start + window]
        mean = np.mean(chunk)
        means.append(mean)
        deviations.append(np.mean(abs(chunk - mean)))
    return means, deviations


tvd_error, kl_error = {}, {}
new_tvd = {}
new_kl = {}
xaxis = []
for dist in tvd_diff:
    # Bound by the shorter of the two curves so TVD and KL are averaged over
    # the same complete windows; a shorter KL curve would otherwise produce
    # partial or empty (NaN) windows.
    horizon = min(epochs, tvd_diff[dist].shape[0], kl_diff[dist].shape[0])
    new_tvd[dist], tvd_error[dist] = _window_stats(tvd_diff[dist], delta, horizon)
    new_kl[dist], kl_error[dist] = _window_stats(kl_diff[dist], delta, horizon)

    # NOTE(review): the shared x-axis ends up sized from the last distribution
    # only; if curves differ in length the other series will not match it.
    xaxis = [i * delta for i in range(len(new_tvd[dist]))]

# Legend labels follow tvd_diff's insertion order, which is the same order in
# which the averaged curves and their error bands were filled in.
label_keys = tvd_diff.keys()

# One convergence plot per metric: (y-axis label, averaged curves, error bands).
for y_label, curves, spreads in (
    ("Total Variation Distance", new_tvd, tvd_error),
    ("KL Divergence", new_kl, kl_error),
):
    plot_lines("Modular Training distribution convergence", y_label,
               list(curves.values()), xaxis,
               list(label_keys), dashed, [], list(spreads.values()),
               save_plot=False, path=last_exp)